Conversation
Signed-off-by: Jianbing Dong <jianbingd@nvidia.com>
@Jianbing-D can you change this to "ready to review" and fix the PR name?

Hi @guyueh1, please review it.
📝 Walkthrough

This PR adds FP8 weight quantization support to the generation backend. It introduces FP8 casting utilities for block-wise quantization, integrates them into the policy worker to quantize eligible weights during export, and refactors weight-loading logic in the vLLM backend to handle pre-quantized batches.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant PW as MegatronPolicyWorker
    participant FP8 as fp8_train_utils
    participant Export as Export Process
    PW->>PW: _is_fp8_weights_enabled()
    alt FP8 Enabled
        PW->>PW: _iter_params_with_optional_kv_scales()
        loop For Each Parameter
            PW->>FP8: should_quantize_to_fp8(name, tensor)
            FP8-->>PW: bool (eligible 2-D weight?)
            alt Eligible for Quantization
                PW->>FP8: cast_tensor_to_fp8_blockwise(tensor, block_size)
                FP8->>FP8: Pad to block multiples
                FP8->>FP8: Compute per-block scales
                FP8->>FP8: Cast to float8_e4m3fn
                FP8-->>PW: (fp8_data, scale_inv)
                PW->>Export: Yield (name, fp8_data)
                PW->>Export: Yield (name_scale_inv, scale)
            else Not Eligible
                PW->>Export: Yield (name, original_tensor)
            end
        end
    else FP8 Disabled
        PW->>Export: Yield parameters unchanged
    end
```
```mermaid
sequenceDiagram
    participant Backend as vllm_backend
    participant Config as vllm_config
    participant FP8Module as fp8 module
    participant Model as Model
    Backend->>Backend: update_weights_via_ipc_zmq(weights)
    Backend->>Backend: _load_model_weights(weights, model_runner)
    Backend->>Config: is_fp8_model(vllm_config)?
    alt FP8 Model
        Backend->>Backend: Detect pre-quantized (_scale_inv present)
        alt Pre-quantized Weights
            Backend->>Model: load_weights(weights)
            Model-->>Backend: Loaded
        else Non Pre-quantized
            Backend->>FP8Module: load_weights(weights, model_runner)
            FP8Module->>Model: Apply FP8 transformations
            FP8Module-->>Backend: Loaded
        end
    else Non-FP8 Model
        Backend->>Model: load_weights(weights)
        Model-->>Backend: Loaded
    end
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes

🚥 Pre-merge checks: ✅ 3 passed | ❌ 1 failed (warning)
Actionable comments posted: 1
🧹 Nitpick comments (4)
nemo_rl/models/generation/vllm/vllm_backend.py (1)
234-234: Replace lambda with local function per style guidelines.

Static analysis flagged E731: assigning a lambda expression. Use a `def` statement instead (note the `return` must be kept so behavior is unchanged).

♻️ Suggested fix

```diff
- load_model_weight_func = lambda x: self._load_model_weights(x, self.model_runner)
+ def load_model_weight_func(x):
+     return self._load_model_weights(x, self.model_runner)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@nemo_rl/models/generation/vllm/vllm_backend.py` at line 234, replace the lambda assignment to load_model_weight_func with a local def function to satisfy style rules: create a small local function (e.g., def load_model_weight_func(path): return self._load_model_weights(path, self.model_runner)) and use that function in place of the lambda; keep the same name load_model_weight_func and ensure it calls self._load_model_weights with the same arguments (path and self.model_runner) so behavior does not change.

nemo_rl/models/generation/vllm/quantization/fp8_train_utils.py (3)
57-62: Consider zero-padding instead of edge-value padding.

Using `data_hp[-1, -1]` for padding could propagate anomalous values (NaN, Inf, or outliers) if they happen to be at the tensor edge. Zero-padding is more conventional and predictable for quantization.

🔧 Suggested fix

```diff
  data_hp = torch.nn.functional.pad(
-     data_hp, (0, pad1, 0, pad0), mode="constant", value=data_hp[-1, -1]
+     data_hp, (0, pad1, 0, pad0), mode="constant", value=0.0
  )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@nemo_rl/models/generation/vllm/quantization/fp8_train_utils.py` around lines 57 - 62, Replace edge-value padding with zero-padding in the block-alignment logic: when computing pad0/pad1 and calling torch.nn.functional.pad on data_hp (the block-padding branch that checks data_hp.shape[...] % block_size...), change the pad fill from value=data_hp[-1, -1] to value=0 so that padding uses zeros instead of the tensor's edge element; keep the same pad tuple order and mode="constant" so only the fill value changes.
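The hazard the comment describes is easy to reproduce. A small demonstration (tensor values are made up for illustration) contrasting the two fill choices when a NaN happens to sit at the tensor's bottom-right corner:

```python
import torch

t = torch.tensor([[1.0, 2.0], [3.0, float("nan")]])
# Edge-value fill copies whatever sits at t[-1, -1] -- here a NaN -- into
# every padded cell; zero fill keeps the padding inert.
edge = torch.nn.functional.pad(t, (0, 2, 0, 2), mode="constant", value=t[-1, -1].item())
zero = torch.nn.functional.pad(t, (0, 2, 0, 2), mode="constant", value=0.0)
```

With edge-value fill the NaN contaminates entire padded blocks, which then poisons those blocks' max-abs scales; with zero fill only the original element is NaN.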
69-69: Undocumented square-block constraint.

The function signature accepts `weight_block_size: list[int]`, suggesting arbitrary block dimensions, but this assertion enforces square blocks. Consider documenting this restriction in the docstring or simplifying the signature if non-square blocks are never intended.

📝 Document the restriction in docstring

```diff
  Args:
      data_hp: 2-D high-precision weight tensor (any float dtype).
-     weight_block_size: [block_rows, block_cols], e.g. [128, 128].
+     weight_block_size: [block_rows, block_cols], e.g. [128, 128].
+         Note: block_rows must equal block_cols (square blocks only).
      use_pow2_scale: If True, round scale factors to powers of two.
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@nemo_rl/models/generation/vllm/quantization/fp8_train_utils.py` at line 69, The code currently asserts block_size0 == block_size1 (from the assert block_size0 == block_size1) while the function accepts weight_block_size: list[int], which implies non-square blocks are allowed; update the function’s docstring (the function that takes weight_block_size) to state explicitly that only square blocks are supported (i.e., weight_block_size must be [N, N]) or, if non-square blocks are never intended, simplify the signature to accept a single int for block_size and remove the list ambiguity; reference the weight_block_size parameter and the assert block_size0 == block_size1 when documenting the constraint.
52-52: Consider using an explicit exception instead of `assert` for input validation.

Assertions can be disabled with Python's `-O` flag, which would bypass this validation in optimized runs. For production code, explicit exceptions are more reliable.

🔧 Suggested fix

```diff
- assert len(data_hp.shape) == 2, "Only 2-D input tensor is supported"
+ if len(data_hp.shape) != 2:
+     raise ValueError("Only 2-D input tensor is supported")
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@nemo_rl/models/generation/vllm/quantization/fp8_train_utils.py` at line 52, Replace the runtime-only assert on data_hp with an explicit input validation that always runs: check if len(data_hp.shape) != 2 and raise a ValueError with the message "Only 2-D input tensor is supported"; update the check around the data_hp usage in fp8_train_utils.py (the line currently doing `assert len(data_hp.shape) == 2, ...`) to this explicit conditional to ensure validation cannot be skipped under -O.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: d17786db-f935-4e85-9778-fd41cce69b35
📒 Files selected for processing (3)
- nemo_rl/models/generation/vllm/quantization/fp8_train_utils.py
- nemo_rl/models/generation/vllm/vllm_backend.py
- nemo_rl/models/policy/workers/megatron_policy_worker.py
```python
else:
    scale_fp = max_dtype / max_abs
    scale_fp = torch.where(max_abs == 0, 1.0, scale_fp)
    scale_fp = torch.where(max_abs == torch.inf, 1.0, scale_fp)
    descale_fp = torch.reciprocal(scale_fp)
```
NaN values are not handled in scale computation.
The linear scale path handles `max_abs == 0` and `max_abs == inf`, but NaN values would propagate silently. If any block contains NaN, both `scale_fp` and the resulting `fp8_data` would be NaN.

🛡️ Suggested fix to handle NaN

```diff
  else:
      scale_fp = max_dtype / max_abs
      scale_fp = torch.where(max_abs == 0, 1.0, scale_fp)
      scale_fp = torch.where(max_abs == torch.inf, 1.0, scale_fp)
+     scale_fp = torch.where(torch.isnan(max_abs), 1.0, scale_fp)
      descale_fp = torch.reciprocal(scale_fp)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@nemo_rl/models/generation/vllm/quantization/fp8_train_utils.py` around lines
87 - 91, The scale computation doesn't guard against NaN in max_abs which will
produce NaN scale_fp and downstream fp8_data; update the branch that computes
scale_fp/descale_fp to also replace NaNs (e.g., using torch.isnan or
torch.isfinite) with a safe fallback (1.0) before taking the reciprocal so that
scale_fp = max_dtype / max_abs is followed by handling max_abs == 0, max_abs ==
inf, and max_abs == NaN (set those scale entries to 1.0), then compute
descale_fp = torch.reciprocal(scale_fp); modify the existing scale_fp and
descale_fp logic where those symbols are defined to include the NaN check.
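All three degenerate cases (zero, Inf, NaN) can be folded into a single finiteness guard rather than three separate `torch.where` calls. A hedged sketch of that hardened scale path — the function name is illustrative, and 448.0 is the float8_e4m3fn maximum assumed by the surrounding code:

```python
import torch

def safe_linear_scale(max_abs: torch.Tensor, fp8_max: float = 448.0):
    """Per-block linear scale with a neutral 1.0 fallback for bad blocks."""
    scale_fp = fp8_max / max_abs
    # One guard covers all three degenerate cases: 0 (-> inf scale),
    # inf (-> 0 scale), and NaN (-> NaN scale).
    bad = (max_abs == 0) | ~torch.isfinite(max_abs)
    scale_fp = torch.where(bad, torch.ones_like(scale_fp), scale_fp)
    return scale_fp, torch.reciprocal(scale_fp)
```

Using `~torch.isfinite` keeps the guard future-proof: any new non-finite pattern in `max_abs` maps to the same safe fallback, and `descale_fp` stays finite everywhere.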
```python
if not name.endswith(".weight"):
    return False
lower = name.lower()
if any(kw in lower for kw in ("norm", "embed", "lm_head")):
```
I think this is a bit too hacky; is it possible to obtain the list of param names to be quantized from the `is_fp8_weight` function on the vLLM side? This info can be synced one time and reused for all subsequent steps.
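The suggestion above amounts to evaluating the eligibility predicate once and caching the result. A minimal sketch of that one-time sync — both `build_quantize_name_set` and the `is_fp8_weight`-style predicate are hypothetical names standing in for whatever the vLLM side exposes:

```python
def build_quantize_name_set(param_names, is_fp8_weight):
    # One-time sync: evaluate the predicate once, then reuse the frozen
    # set on every subsequent refit step instead of re-matching patterns.
    return frozenset(n for n in param_names if is_fp8_weight(n))
```

Each refit step then reduces eligibility to a set-membership test, and the name heuristics live in one place on the vLLM side.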
Signed-off-by: Jianbing Dong <jianbingd@nvidia.com>
What does this PR do ?
Quantize before weight transfer to accelerate refit in FP8 GRPO.
Describe here: https://nvbugspro.nvidia.com/bug/5863778
Weights are periodically synced from the Megatron training worker to the vLLM generation worker. FP8 refit quantizes weights on the training side before broadcast, reducing the network payload (BF16 -> FP8 + compact scales).
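The payload reduction is straightforward to estimate: BF16 costs 2 bytes per element, while FP8 costs 1 byte per element plus one FP32 `scale_inv` per 128x128 block. A back-of-envelope sketch (helper name and block size are illustrative):

```python
import math

def refit_payload_bytes(rows: int, cols: int, block: int = 128, fp8: bool = True) -> int:
    """Estimate the broadcast payload for one 2-D weight."""
    if not fp8:
        return rows * cols * 2  # BF16: 2 bytes per element
    n_blocks = math.ceil(rows / block) * math.ceil(cols / block)
    return rows * cols + n_blocks * 4  # FP8 data + one FP32 scale_inv per block
```

For a 4096x4096 weight the scales add only 4 KiB on top of 16 MiB of FP8 data, so the transfer is cut roughly in half versus BF16.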
Issues
Usage
# Add a code snippet demonstrating how to use this

Before your PR is "Ready for review"
Pre checks:
Additional Information
Test results
Verified on llama3.1-8B.
https://wandb.ai/nv-default-onboard/nemo-rl/reports/-FP8-Refit-Optimization--VmlldzoxNjI2NTkwMg
Summary by CodeRabbit
Release Notes
New Features
Refactor